package cn.dagteam.springboot.mongodb.starter.listener;

import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;

import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.mapping.DBRef;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ReflectionUtils.FieldFilter;

import cn.dagteam.springboot.mongodb.starter.Cascade;
import cn.dagteam.springboot.mongodb.starter.Cascade.CascadeType;
import lombok.AllArgsConstructor;

/**
 * Field callback that keeps the inverse side of a cascaded relationship in sync.
 *
 * <p>For every visited field of {@code source} that is annotated with {@link Cascade} and whose
 * cascade types include {@code ALL} or {@code SAVE}, the referenced object (or each element, when
 * the field holds a collection) is searched for a {@link DBRef} field pointing back to the class of
 * {@code source}. When such a field exists, {@code source} is written into it and the referenced
 * object is saved through the {@link MongoTemplate}.
 */
@AllArgsConstructor
public class CascadeSaveLinkCallback implements ReflectionUtils.FieldCallback {

    /** Object whose fields are visited; it is the value written into each back-reference. */
    private final Object source;

    /** Template used to save every linked object after its back-reference has been updated. */
    private final MongoTemplate mongoTemplate;

    /**
     * Reads the given field of {@code source} through its bean getter and, when the field cascades
     * saves, links {@code source} into the referenced object(s).
     *
     * @param field field of {@code source}'s class to inspect
     * @throws RuntimeException if the field is not exposed as a JavaBean property
     */
    @Override
    public void doWith(Field field) throws IllegalArgumentException, IllegalAccessException {
        Object value;
        try {
            PropertyDescriptor pd = new PropertyDescriptor(field.getName(), source.getClass());
            value = ReflectionUtils.invokeMethod(pd.getReadMethod(), source);
        } catch (IntrospectionException e) {
            throw new RuntimeException("不是javabean的对象", e);
        }
        if (value == null) {
            return;
        }

        // A field without @Cascade has nothing to cascade; skip it instead of failing with an NPE.
        Cascade cascade = field.getAnnotation(Cascade.class);
        if (cascade == null) {
            return;
        }
        boolean cascadesSave = Arrays.stream(cascade.value())
                .anyMatch(type -> CascadeType.ALL.equals(type) || CascadeType.SAVE.equals(type));
        if (!cascadesSave) {
            return;
        }

        if (value instanceof Collection) {
            for (Object element : (Collection<?>) value) {
                saveLink(element);
            }
        } else {
            saveLink(value);
        }
    }

    /**
     * Writes {@code source} into the back-reference field of {@code value} (if it has one) and
     * saves {@code value}. Objects without a matching back-reference field are left untouched and
     * are not saved.
     *
     * @param value referenced object; {@code null} is ignored
     * @throws RuntimeException if the back-reference field is not exposed as a JavaBean property
     * @throws IllegalStateException if the back-reference is an unset collection whose type cannot
     *         be instantiated
     */
    private void saveLink(Object value) {
        if (value == null) {
            return;
        }

        Class<?> valueClass = value.getClass();
        // Typed local so the FieldFilter overload of findField is selected explicitly.
        FieldFilter backReferenceFilter = this::isBackReference;
        Field f = org.springframework.data.util.ReflectionUtils.findField(valueClass, backReferenceFilter);
        if (f == null) {
            return;
        }

        try {
            PropertyDescriptor wpd = new PropertyDescriptor(f.getName(), valueClass);
            Object current = ReflectionUtils.invokeMethod(wpd.getReadMethod(), value);
            if (current == null && Collection.class.isAssignableFrom(f.getType())) {
                // The freshly created collection must be stored on the object, otherwise the
                // link added below would be lost before the save.
                current = newCollectionFor(f.getType());
                ReflectionUtils.invokeMethod(wpd.getWriteMethod(), value, current);
            }
            if (current instanceof Collection) {
                @SuppressWarnings("unchecked") // element type was matched against source's class by the filter
                Collection<Object> links = (Collection<Object>) current;
                // contains() tolerates null elements and avoids duplicate entries in lists.
                if (!links.contains(source)) {
                    links.add(source);
                }
            } else {
                ReflectionUtils.invokeMethod(wpd.getWriteMethod(), value, source);
            }
        } catch (IntrospectionException e) {
            throw new RuntimeException("不是javabean的对象", e);
        }
        mongoTemplate.save(value);
    }

    /**
     * Tells whether the candidate field is a {@link DBRef} whose type — or, for collection fields,
     * whose first generic type argument — is exactly the class of {@code source}.
     */
    private boolean isBackReference(Field candidate) {
        if (!candidate.isAnnotationPresent(DBRef.class)) {
            return false;
        }
        Class<?> type = candidate.getType();
        // For collection fields, compare against the first generic type argument instead.
        if (Collection.class.isAssignableFrom(type)) {
            type = getParameterizedType(candidate, 0);
        }
        // type may be null for raw collections; equals(null) is simply false.
        return source.getClass().equals(type);
    }

    /**
     * Creates an empty collection assignable to the given field type: an {@link ArrayList} or
     * {@link HashSet} where those fit, otherwise an instance built with the type's no-arg constructor.
     */
    @SuppressWarnings("unchecked")
    private static Collection<Object> newCollectionFor(Class<?> fieldType) {
        if (fieldType.isAssignableFrom(ArrayList.class)) {
            return new ArrayList<>();
        }
        if (fieldType.isAssignableFrom(HashSet.class)) {
            return new HashSet<>();
        }
        try {
            return (Collection<Object>) fieldType.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot create collection of type " + fieldType.getName(), e);
        }
    }

    /**
     * Resolves the class of one generic type argument of a field's declared type.
     *
     * @param field field whose generic type is inspected
     * @param index position of the type argument
     * @return the argument's class (the raw class when the argument is itself parameterized), or
     *         {@code null} when the field is not parameterized or the argument is not a concrete
     *         class (for example a wildcard or a type variable)
     * @throws ArrayIndexOutOfBoundsException if {@code index} is outside the type-argument range
     */
    public static Class<?> getParameterizedType(Field field, int index) {
        Type genericType = field.getGenericType();
        // Only parameterized declarations carry type arguments.
        if (!(genericType instanceof ParameterizedType)) {
            return null;
        }
        Type argument = ((ParameterizedType) genericType).getActualTypeArguments()[index];
        if (argument instanceof Class) {
            return (Class<?>) argument;
        }
        if (argument instanceof ParameterizedType) {
            Type raw = ((ParameterizedType) argument).getRawType();
            return raw instanceof Class ? (Class<?>) raw : null;
        }
        return null;
    }
}
